package com.xiaofan.udftest

import com.xiaofan.apitest.source.SensorReading
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor
import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment
import org.apache.flink.table.api.{Table, _}
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.types.Row

/**
 * Demo job for a user-defined aggregate function: registers [[AvgTemp]] and uses it to
 * average the `temperature` column per `id`, once through the Table API and once through SQL.
 */
object AggFunctionTest {

  /** Input file used when no path is supplied on the command line. */
  private val DefaultInputPath: String =
    "D:\\big-data\\code\\FlinkTutorial\\src\\main\\resources\\sensor.txt"

  /**
   * Builds the pipeline and runs it.
   *
   * @param args optional command-line arguments; when present, `args(0)` is the path of the
   *             sensor text file to read. Without it the job reads `DefaultInputPath`.
   */
  def main(args: Array[String]): Unit = {

    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(1)
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    val tabEnv: StreamTableEnvironment = StreamTableEnvironment.create(env)

    // Prefer a path given on the command line so the job is not tied to one machine's
    // directory layout; `Option(args)` keeps a null argument array working as before.
    val inputPath: String = Option(args).flatMap(_.headOption).getOrElse(DefaultInputPath)
    val inputStream: DataStream[String] = env.readTextFile(inputPath)

    // Each line is split on "," into three fields that are passed to SensorReading in order.
    val dataStream: DataStream[SensorReading] = inputStream.map(
      data => {
        val arr: Array[String] = data.split(",")
        SensorReading(arr(0), arr(1).toLong, arr(2).toDouble)
      }
    ).assignTimestampsAndWatermarks(new BoundedOutOfOrdernessTimestampExtractor[SensorReading](Time.milliseconds(3)) {
      // The stored timestamp is scaled by 1000 to obtain the event time in milliseconds
      // (presumably it is recorded in epoch seconds - confirm against the data file).
      override def extractTimestamp(element: SensorReading): Long = element.timestamp * 1000L
    })

    // Expose the stream as a table; the record's timestamp becomes the rowtime attribute "ts".
    val sensorTable: Table = tabEnv.fromDataStream(dataStream, $"id", $"temperature", $"timestamp".rowtime() as "ts")

    // table api
    val avgTemp = new AvgTemp
    tabEnv.createTemporaryFunction("avgTemp", avgTemp)

    // NOTE(review): this Table API query is built but never converted to a stream or printed;
    // only the SQL result below is emitted.
    val resultTable: Table = sensorTable
      .groupBy($"id")
      .aggregate(avgTemp($"temperature") as "avgTemp")
      .select($"id", $"avgTemp")

    // sql
    tabEnv.createTemporaryView("sensor", sensorTable)
    val sqlTableResult: Table = tabEnv.sqlQuery(
      """
        |select id, avgTemp(temperature)
        |from
        |sensor
        |group by id
        |""".stripMargin)

    // A grouped aggregate updates its result as rows arrive, so it is emitted as a retract stream.
    tabEnv.toRetractStream[Row](sqlTableResult).print("table")


    env.execute("function test")


  }
}

/**
 * Accumulator for [[AvgTemp]]: the running sum of temperatures and the number of readings
 * folded in so far. Both fields are `var`s because `AvgTemp.accumulate` updates them in place.
 *
 * NOTE(review): `count` is a 32-bit `Int`, so it would wrap after roughly 2.1 billion
 * accumulated readings for one key - consider whether that bound matters for the workload.
 *
 * @param sum   sum of the temperatures accumulated so far (defaults to 0.0)
 * @param count number of temperatures accumulated so far (defaults to 0)
 */
case class AvgTempAcc(var sum: Double = 0.0, var count: Int = 0)

/**
 * User-defined aggregate function used to compute each sensor's average temperature.
 * The intermediate state (temperature sum and reading count) is kept in an [[AvgTempAcc]].
 */
class AvgTemp extends AggregateFunction[Double, AvgTempAcc] {

  /** Supplies a fresh accumulator built from `AvgTempAcc`'s defaults (sum 0.0, count 0). */
  override def createAccumulator(): AvgTempAcc = AvgTempAcc()

  /**
   * Folds one temperature reading into the accumulator, mutating it in place.
   *
   * Declared without `override`; Flink's aggregate-function contract expects a public
   * method with this name.
   *
   * @param accumulator the state to update
   * @param temp        the temperature to add
   */
  def accumulate(accumulator: AvgTempAcc, temp: Double): Unit = {
    accumulator.sum = accumulator.sum + temp
    accumulator.count = accumulator.count + 1
  }

  /**
   * Produces the current average as sum divided by count.
   *
   * @param accumulator the state to read
   * @return the mean of the accumulated temperatures; NaN for an accumulator that is still
   *         at its defaults (0.0 / 0)
   */
  override def getValue(accumulator: AvgTempAcc): Double = {
    val total = accumulator.sum
    val samples = accumulator.count
    total / samples
  }
}

